{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tf845nO8_-hV"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "form",
        "id": "EEczi3q7AOXD"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7InkT1mdAJhT"
      },
      "source": [
        "# Gemma - Model evaluation\n",
        "\n",
        "This notebook demonstrates using EleautherAI's Language Model Evaluation Harness to perform model performance benchmark on Gemma2 2B, specifically using a subset of MMLU.\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]evaluation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SKQ4bH7qMGrA"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use the T4 GPU but you need a high RAM instance (due to [this](https://github.com/google-deepmind/gemma/issues/57)).\n",
        "\n",
        "### Gemma setup\n",
        "\n",
        "To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
        "\n",
        "* Get access to Gemma on kaggle.com.\n",
        "* Select a Colab runtime with sufficient resources to run\n",
        "  the Gemma 2B model.\n",
        "* Generate and configure a Kaggle username and an API key as Colab secrets.\n",
        "\n",
        "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KJintwsi1o87"
      },
      "source": [
        "### Configure your credentials\n",
        "\n",
        "Add your Kaggle credentials to the Colab Secrets manager to securely store it.\n",
        "\n",
        "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "2. Create new secrets: `KAGGLE_USERNAME` and `KAGGLE_KEY`\n",
        "3. Copy/paste your username into `KAGGLE_USERNAME`\n",
        "3. Copy/paste your key into `KAGGLE_KEY`\n",
        "4. Toggle the buttons on the left to allow notebook access to the secrets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ATbyLmuImHTA"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import kagglehub\n",
        "from google.colab import userdata\n",
        "\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
        "\n",
        "# Pre-allocate 90% of accelerator memory\n",
        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"0.9\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KxDcs0t1A7zX"
      },
      "source": [
        "Install the evaluation harness."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bzim2HVVbuy-"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting git+https://github.com/EleutherAI/lm-evaluation-harness.git\n",
            "  Cloning https://github.com/EleutherAI/lm-evaluation-harness.git to /tmp/pip-req-build-1s_z1sm5\n",
            "  Running command git clone --filter=blob:none --quiet https://github.com/EleutherAI/lm-evaluation-harness.git /tmp/pip-req-build-1s_z1sm5\n",
            "  Resolved https://github.com/EleutherAI/lm-evaluation-harness.git to commit 543617fef9ba885e87f8db8930fbbff1d4e2ca49\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: accelerate>=0.26.0 in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (0.33.0)\n",
            "Requirement already satisfied: evaluate in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (0.4.2)\n",
            "Requirement already satisfied: datasets>=2.16.0 in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (2.21.0)\n",
            "Requirement already satisfied: jsonlines in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (4.0.0)\n",
            "Requirement already satisfied: numexpr in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (2.10.1)\n",
            "Requirement already satisfied: peft>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (0.12.0)\n",
            "Requirement already satisfied: pybind11>=2.6.2 in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (2.13.5)\n",
            "Requirement already satisfied: pytablewriter in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (1.2.0)\n",
            "Requirement already satisfied: rouge-score>=0.0.4 in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (0.1.2)\n",
            "Requirement already satisfied: sacrebleu>=1.5.0 in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (2.4.3)\n",
            "Requirement already satisfied: scikit-learn>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (1.3.2)\n",
            "Requirement already satisfied: sqlitedict in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (2.1.0)\n",
            "Requirement already satisfied: torch>=1.8 in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (2.4.0+cu121)\n",
            "Requirement already satisfied: tqdm-multiprocess in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (0.0.11)\n",
            "Requirement already satisfied: transformers>=4.1 in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (4.44.2)\n",
            "Requirement already satisfied: zstandard in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (0.23.0)\n",
            "Requirement already satisfied: dill in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (0.3.8)\n",
            "Requirement already satisfied: word2number in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (1.1)\n",
            "Requirement already satisfied: more-itertools in /usr/local/lib/python3.10/dist-packages (from lm_eval==0.4.4) (10.3.0)\n",
            "Requirement already satisfied: numpy<2.0.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.26.0->lm_eval==0.4.4) (1.26.4)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.26.0->lm_eval==0.4.4) (24.1)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.26.0->lm_eval==0.4.4) (5.9.5)\n",
            "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.26.0->lm_eval==0.4.4) (6.0.2)\n",
            "Requirement already satisfied: huggingface-hub>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.26.0->lm_eval==0.4.4) (0.24.6)\n",
            "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.26.0->lm_eval==0.4.4) (0.4.4)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->lm_eval==0.4.4) (3.15.4)\n",
            "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->lm_eval==0.4.4) (17.0.0)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->lm_eval==0.4.4) (2.1.4)\n",
            "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->lm_eval==0.4.4) (2.32.3)\n",
            "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->lm_eval==0.4.4) (4.66.5)\n",
            "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->lm_eval==0.4.4) (3.5.0)\n",
            "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->lm_eval==0.4.4) (0.70.16)\n",
            "Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets>=2.16.0->lm_eval==0.4.4) (2024.6.1)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->lm_eval==0.4.4) (3.10.5)\n",
            "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from rouge-score>=0.0.4->lm_eval==0.4.4) (2.1.0)\n",
            "Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from rouge-score>=0.0.4->lm_eval==0.4.4) (3.8.1)\n",
            "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from rouge-score>=0.0.4->lm_eval==0.4.4) (1.16.0)\n",
            "Requirement already satisfied: portalocker in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm_eval==0.4.4) (2.10.1)\n",
            "Requirement already satisfied: regex in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm_eval==0.4.4) (2024.5.15)\n",
            "Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm_eval==0.4.4) (0.9.0)\n",
            "Requirement already satisfied: colorama in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm_eval==0.4.4) (0.4.6)\n",
            "Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm_eval==0.4.4) (4.9.4)\n",
            "Requirement already satisfied: scipy>=1.5.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24.1->lm_eval==0.4.4) (1.13.1)\n",
            "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24.1->lm_eval==0.4.4) (1.4.2)\n",
            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24.1->lm_eval==0.4.4) (3.5.0)\n",
            "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm_eval==0.4.4) (4.12.2)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm_eval==0.4.4) (1.13.2)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm_eval==0.4.4) (3.3)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm_eval==0.4.4) (3.1.4)\n",
            "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.1->lm_eval==0.4.4) (0.19.1)\n",
            "Requirement already satisfied: attrs>=19.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonlines->lm_eval==0.4.4) (24.2.0)\n",
            "Requirement already satisfied: setuptools>=38.3.0 in /usr/local/lib/python3.10/dist-packages (from pytablewriter->lm_eval==0.4.4) (71.0.4)\n",
            "Requirement already satisfied: DataProperty<2,>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from pytablewriter->lm_eval==0.4.4) (1.0.1)\n",
            "Requirement already satisfied: mbstrdecoder<2,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from pytablewriter->lm_eval==0.4.4) (1.1.3)\n",
            "Requirement already satisfied: pathvalidate<4,>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from pytablewriter->lm_eval==0.4.4) (3.2.1)\n",
            "Requirement already satisfied: tabledata<2,>=1.3.1 in /usr/local/lib/python3.10/dist-packages (from pytablewriter->lm_eval==0.4.4) (1.3.3)\n",
            "Requirement already satisfied: tcolorpy<1,>=0.0.5 in /usr/local/lib/python3.10/dist-packages (from pytablewriter->lm_eval==0.4.4) (0.1.6)\n",
            "Requirement already satisfied: typepy<2,>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from typepy[datetime]<2,>=1.3.2->pytablewriter->lm_eval==0.4.4) (1.3.2)\n",
            "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->lm_eval==0.4.4) (2.4.0)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->lm_eval==0.4.4) (1.3.1)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->lm_eval==0.4.4) (1.4.1)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->lm_eval==0.4.4) (6.0.5)\n",
            "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->lm_eval==0.4.4) (1.9.4)\n",
            "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->lm_eval==0.4.4) (4.0.3)\n",
            "Requirement already satisfied: chardet<6,>=3.0.4 in /usr/local/lib/python3.10/dist-packages (from mbstrdecoder<2,>=1.0.0->pytablewriter->lm_eval==0.4.4) (5.2.0)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets>=2.16.0->lm_eval==0.4.4) (3.3.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets>=2.16.0->lm_eval==0.4.4) (3.8)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets>=2.16.0->lm_eval==0.4.4) (2.0.7)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets>=2.16.0->lm_eval==0.4.4) (2024.8.30)\n",
            "Requirement already satisfied: python-dateutil<3.0.0,>=2.8.0 in /usr/local/lib/python3.10/dist-packages (from typepy[datetime]<2,>=1.3.2->pytablewriter->lm_eval==0.4.4) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2018.9 in /usr/local/lib/python3.10/dist-packages (from typepy[datetime]<2,>=1.3.2->pytablewriter->lm_eval==0.4.4) (2024.1)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.8->lm_eval==0.4.4) (2.1.5)\n",
            "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk->rouge-score>=0.0.4->lm_eval==0.4.4) (8.1.7)\n",
            "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets>=2.16.0->lm_eval==0.4.4) (2024.1)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.8->lm_eval==0.4.4) (1.3.0)\n"
          ]
        }
      ],
      "source": [
        "!pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hzvwo9Is7mvX"
      },
      "source": [
        "Install the Gemma JAX library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5EUxBOYImMc1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "!pip install -q git+https://github.com/google-deepmind/gemma.git\n",
        "from gemma import params as params_lib\n",
        "import sentencepiece as spm\n",
        "from gemma import transformer as transformer_lib\n",
        "from gemma import sampler as sampler_lib"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KNpy8QsB8aTD"
      },
      "source": [
        "Download the Gemma model and tokenizer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ohqi-FOqmRA8"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.3.0)\n"
          ]
        }
      ],
      "source": [
        "GEMMA_VARIANT = 'gemma2-2b-it'\n",
        "GEMMA_PATH = kagglehub.model_download(f'google/gemma-2/flax/{GEMMA_VARIANT}')\n",
        "CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)\n",
        "TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8NM7ALv5E9sh"
      },
      "outputs": [],
      "source": [
        "params = params_lib.load_and_format_params(CKPT_PATH)\n",
        "vocab = spm.SentencePieceProcessor()\n",
        "vocab.Load(TOKENIZER_PATH)\n",
        "\n",
        "transformer_config = transformer_lib.TransformerConfig.from_params(\n",
        "    params=params,\n",
        "    cache_size=1024\n",
        ")\n",
        "\n",
        "transformer = transformer_lib.Transformer(transformer_config)\n",
        "\n",
        "sampler = sampler_lib.Sampler(\n",
        "    transformer=transformer,\n",
        "    vocab=vocab,\n",
        "    params=params['transformer'],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O4icXT5G9UEZ"
      },
      "source": [
        "Create a new Gemma LM calss, following the [New Model Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/model_guide.md). For MMLU, we only need to implement the loglikelihood() function."
      ]
    },
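    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why `loglikelihood()` is enough for MMLU, it helps to look at how a multiple-choice question can be scored. The following is a minimal sketch, not the harness's actual code: it assumes each question produces one request per answer choice and that the choice with the highest log-likelihood is taken as the prediction. The function name `score_question` and all numbers are hypothetical.\n",
        "\n",
        "```python\n",
        "def score_question(choice_logprobs, gold_index):\n",
        "    # Return 1.0 if the highest-scoring choice is the gold answer, else 0.0.\n",
        "    predicted = max(range(len(choice_logprobs)), key=lambda i: choice_logprobs[i])\n",
        "    return 1.0 if predicted == gold_index else 0.0\n",
        "\n",
        "\n",
        "# Hypothetical log-likelihoods for choices A, B, C, D on two questions.\n",
        "questions = [\n",
        "    ([-2.3, -0.4, -3.1, -2.9], 1),  # model prefers B, gold is B\n",
        "    ([-0.9, -1.2, -2.5, -3.0], 2),  # model prefers A, gold is C\n",
        "]\n",
        "\n",
        "scores = [score_question(lls, gold) for lls, gold in questions]\n",
        "accuracy = sum(scores) / len(scores)\n",
        "print(accuracy)\n",
        "```\n",
        "\n",
        "Under this assumption, a task with four answer choices per question issues four `loglikelihood` requests per question, which is consistent with the run in this notebook, where 103 questions produce 412 requests."
      ]
    },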
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RV86s4JVtVCi"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from typing import Optional, Union, List, Dict\n",
        "import kagglehub\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from jinja2 import Template\n",
        "from tqdm import tqdm\n",
        "import lm_eval\n",
        "from lm_eval.api.instance import Instance\n",
        "from lm_eval.api.model import LM\n",
        "from lm_eval.api.registry import register_model\n",
        "from lm_eval.models import utils\n",
        "\n",
        "\n",
        "@register_model(\"Gemma2\")\n",
        "class Gemma2LM(LM):\n",
        "\n",
        "    def __init__(self, batch_size: Optional[int] = 1):\n",
        "        self._batch_size = batch_size\n",
        "        self._rank = 0\n",
        "        self._world_size = 1\n",
        "\n",
        "    @property\n",
        "    def batch_size(self):\n",
        "        return self._batch_size\n",
        "\n",
        "    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:\n",
        "        results = []\n",
        "\n",
        "        for chunked_request in utils.chunks(tqdm(requests, disable=False), self._batch_size):\n",
        "            contexts = [req.args[0] for req in chunked_request]\n",
        "            next_tokens = [req.args[1] for req in chunked_request]\n",
        "\n",
        "            outputs = sampler(input_strings=contexts, total_generation_steps=1)\n",
        "\n",
        "            for i, logits in enumerate(outputs.logits):\n",
        "                next_token = next_tokens[i]\n",
        "\n",
        "                next_token_id = vocab.EncodeAsIds(next_token)[0]\n",
        "                logits = jnp.array(logits[0])  # Assuming generating one token\n",
        "\n",
        "                log_probs = jax.nn.log_softmax(logits, axis=-1)\n",
        "\n",
        "                next_token_logprob = log_probs[next_token_id]\n",
        "\n",
        "                is_greedy = next_token_id == log_probs.argmax()\n",
        "\n",
        "                results.append((float(next_token_logprob), is_greedy))\n",
        "\n",
        "        return results\n",
        "\n",
        "    def loglikelihood_rolling(self, requests: list[Instance]) -> list[tuple[float, bool]]:\n",
        "        # Used to evaluate perplexity; not important for this tutorial\n",
        "        raise NotImplementedError(\"loglikelihood_rolling is not implemented\")\n",
        "\n",
        "\n",
        "    def generate_until(self, requests: list[Instance]) -> list[str]:\n",
        "        # Not used for MMLU\n",
        "        raise NotImplementedError(\"generat_until is not implemented\")\n"
      ]
    },
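    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The core of `loglikelihood()` is turning logits into a log-probability for the target token and checking whether that token is also the greedy choice. The following sketch shows that step in isolation on a toy three-token vocabulary. It uses only the Python standard library instead of JAX, and the logits, the target id, and the helper name `log_softmax` are made up for illustration.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def log_softmax(logits):\n",
        "    # Numerically stable log-softmax over a list of floats.\n",
        "    m = max(logits)\n",
        "    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))\n",
        "    return [x - log_sum_exp for x in logits]\n",
        "\n",
        "\n",
        "toy_logits = [2.0, 1.0, 0.1]  # hypothetical logits for a 3-token vocabulary\n",
        "target_id = 0                 # hypothetical id of the continuation token\n",
        "\n",
        "log_probs = log_softmax(toy_logits)\n",
        "target_logprob = log_probs[target_id]\n",
        "\n",
        "# The target is greedy if it is the highest-probability token.\n",
        "is_greedy = target_id == max(range(len(log_probs)), key=lambda i: log_probs[i])\n",
        "\n",
        "print(target_logprob, is_greedy)\n",
        "```\n",
        "\n",
        "The pair `(target_logprob, is_greedy)` has the same shape as each tuple that `loglikelihood()` appends to its results list."
      ]
    },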
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SfTlmr7t9JMA"
      },
      "source": [
        "Now start the evaluation process. It takes a long time to run through the entire MMLU benchmark, so we are just going to run a subset of it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8KZBGmcAuFKk"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO:lm-eval:Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234\n",
            "INFO:lm-eval:Using pre-initialized model\n",
            "INFO:lm-eval:`group` and `group_alias` keys in TaskConfigs are deprecated and will be removed in v0.4.5 of lm_eval. The new `tag` field will be used to allow for a shortcut to a group of tasks one does not wish to aggregate metrics across. `group`s which aggregate across subtasks must be only defined in a separate group config file, which will be the official way to create groups that support cross-task aggregation as in `mmlu`. Please see the v0.4.4 patch notes and our documentation: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/new_task_guide.md#advanced-group-configs for more information.\n",
            "WARNING:lm-eval:Overwriting default num_fewshot of mmlu_management from None to 3\n",
            "INFO:lm-eval:Setting fewshot random generator seed to 1234\n",
            "INFO:lm-eval:Building contexts for mmlu_management on rank 0...\n",
            "100%|██████████| 103/103 [00:00<00:00, 118.20it/s]\n",
            "INFO:lm-eval:Running loglikelihood requests\n",
            "100%|██████████| 412/412 [55:42<00:00,  8.11s/it]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{'mmlu_management': {'alias': 'management', 'acc,none': 0.6601941747572816, 'acc_stderr,none': 0.046897659372781335}}\n"
          ]
        }
      ],
      "source": [
        "import lm_eval\n",
        "results = lm_eval.simple_evaluate(\n",
        "    model=Gemma2LM(batch_size=4),\n",
        "    tasks=[\"mmlu_management\"],\n",
        "    num_fewshot=3,\n",
        ")\n",
        "\n",
        "print(results['results'])"
      ]
    }
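    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check on the output above, the reported accuracy and standard error can be reproduced by hand. This sketch assumes that each of the 103 `mmlu_management` questions is scored 0 or 1 and that the standard error is the sample standard deviation (with an `n - 1` denominator) divided by the square root of the number of questions. The count of 68 correct answers is inferred from the reported accuracy, not read from the harness.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "n_questions = 103  # from the progress bar in the run above\n",
        "n_correct = 68     # inferred: round(0.6601941747572816 * 103)\n",
        "\n",
        "accuracy = n_correct / n_questions\n",
        "\n",
        "# For 0/1 scores, the sample variance is n * p * (1 - p) / (n - 1),\n",
        "# so the standard error of the mean is sqrt(p * (1 - p) / (n - 1)).\n",
        "stderr = math.sqrt(accuracy * (1 - accuracy) / (n_questions - 1))\n",
        "\n",
        "print(accuracy, stderr)\n",
        "```\n",
        "\n",
        "With only 103 questions the standard error is close to 0.05, so small differences in accuracy on this subset should be interpreted with care."
      ]
    }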
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[Gemma_2]evaluation.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
